{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc1311d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f73a3813",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot single variable trend\n",
    "# Choose which model / dataset outputs to inspect.\n",
    "model_name = 'PatchTST'\n",
    "dataset_name = 'weather'\n",
    "\n",
    "# Both arrays are read from ../outputs/{model_name}/{dataset_name}/\n",
    "output_dir = f'../outputs/{model_name}/{dataset_name}'\n",
    "preds_path = os.path.join(output_dir, f'preds_{dataset_name}_inv.npy')\n",
    "trues_path = os.path.join(output_dir, f'trues_{dataset_name}_inv.npy')\n",
    "\n",
    "preds, trues = np.load(preds_path), np.load(trues_path)\n",
    "print(f'Pred shape: {preds.shape}, True shape: {trues.shape}')\n",
    "\n",
    "# sample_idx: sample(s) to plot, either one int or a list such as [0, 10, 100]\n",
    "sample_idx = 0\n",
    "# step_range: (start, end) steps to plot, e.g. (0, 24) for the first 24 steps; None plots all\n",
    "step_range = None\n",
    "\n",
    "def plot_single_var(preds, trues, sample_idx=0, step_range=None, var_idx=None):\n",
    "    \"\"\"\n",
    "    Plot predicted vs. true values, drawing one figure per requested sample.\n",
    "\n",
    "    Args:\n",
    "        preds: np.ndarray of predictions, indexed by sample on the first axis\n",
    "        trues: np.ndarray of ground truth, indexed the same way as preds\n",
    "        sample_idx: int (Python or NumPy integer) or iterable of ints, samples to plot\n",
    "        step_range: (start, end) pair applied as the slice [start:end] over the steps;\n",
    "            either bound may be None, negative, or beyond the last step. None plots all steps.\n",
    "        var_idx: optional index into the last axis of each sample, for plotting a single\n",
    "            variable (assumes variables sit on the last axis). None plots preds[idx] as is.\n",
    "    \"\"\"\n",
    "    # NumPy integers (e.g. taken from np.arange) are scalars too and cannot be iterated.\n",
    "    if isinstance(sample_idx, (int, np.integer)):\n",
    "        sample_idx = [sample_idx]\n",
    "    for idx in sample_idx:\n",
    "        y_pred = preds[idx]\n",
    "        y_true = trues[idx]\n",
    "        if var_idx is not None:\n",
    "            y_pred = y_pred[..., var_idx]\n",
    "            y_true = y_true[..., var_idx]\n",
    "        steps = np.arange(len(y_pred))\n",
    "        if step_range is not None:\n",
    "            start, end = step_range\n",
    "            # Slice x positions and data with the same window so their lengths always\n",
    "            # match, even when the range runs past the end or uses negative/None bounds.\n",
    "            window = slice(start, end)\n",
    "            steps = steps[window]\n",
    "            y_pred = y_pred[window]\n",
    "            y_true = y_true[window]\n",
    "        fig, ax = plt.subplots(figsize=(8, 4))\n",
    "        ax.plot(steps, y_true, label='True', marker='o')\n",
    "        ax.plot(steps, y_pred, label='Pred', marker='x')\n",
    "        ax.set_title(f'Sample {idx}')\n",
    "        ax.set_xlabel('Step')\n",
    "        ax.set_ylabel('Value')\n",
    "        ax.legend()\n",
    "        ax.grid(True, alpha=0.3)\n",
    "        fig.tight_layout()\n",
    "        plt.show()\n",
    "        # Release the figure so plotting many samples does not accumulate open figures.\n",
    "        plt.close(fig)\n",
    "\n",
    "# Draw one figure per selected sample, using the settings chosen above.\n",
    "plot_single_var(preds, trues, sample_idx, step_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89504010",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_forecast(preds, trues, var_indices, sample_idx=0, var_names=None, figsize=(10, 5)):\n",
    "    \"\"\"\n",
    "    Plot prediction vs. ground truth for specified variables and sample.\n",
    "\n",
    "    All requested variables share one axes: predictions are dashed, ground truth is solid.\n",
    "\n",
    "    Args:\n",
    "        preds: np.ndarray, predictions, shape [num_samples, pred_len, num_vars]\n",
    "        trues: np.ndarray, ground truth, same shape as preds\n",
    "        var_indices: int (Python or NumPy integer) or iterable of ints, indices of variables to plot\n",
    "        sample_idx: int, which sample to plot\n",
    "        var_names: sequence of str indexed by variable index (optional); a list, tuple,\n",
    "            NumPy array or pandas Index all work. None or empty falls back to \"Var N\" labels.\n",
    "        figsize: tuple, figure size\n",
    "    \"\"\"\n",
    "    # NumPy integers (e.g. taken from np.arange) are scalars too and cannot be iterated.\n",
    "    if isinstance(var_indices, (int, np.integer)):\n",
    "        var_indices = [var_indices]\n",
    "    # Explicit None/length check: `not var_names` raises for NumPy arrays and pandas Index objects.\n",
    "    has_names = var_names is not None and len(var_names) > 0\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    for var_idx in var_indices:\n",
    "        name = var_names[var_idx] if has_names else f\"Var {var_idx}\"\n",
    "        ax.plot(preds[sample_idx, :, var_idx], label=f\"Predicted - {name}\", linestyle='--')\n",
    "        ax.plot(trues[sample_idx, :, var_idx], label=f\"True - {name}\", linestyle='-')\n",
    "    ax.set_xlabel(\"Forecast Step\")\n",
    "    ax.set_ylabel(\"Value\")\n",
    "    ax.set_title(f\"Sample {sample_idx} Forecast Comparison\")\n",
    "    ax.legend()\n",
    "    ax.grid(True)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "    # Release the figure so repeated calls in a loop do not accumulate open figures.\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "# Load the saved arrays here as well, so this cell does not rely on the previous one.\n",
    "preds = np.load(\"../outputs/PatchTST/weather/preds_weather_inv.npy\")\n",
    "trues = np.load(\"../outputs/PatchTST/weather/trues_weather_inv.npy\")\n",
    "\n",
    "# Compare variable 1 against ground truth for three samples.\n",
    "for i in (0, 10, 100):\n",
    "    plot_forecast(preds, trues, 1, sample_idx=i)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
